iT邦幫忙

第 11 屆 iThome 鐵人賽

DAY 26
0
Google Developers Machine Learning

Towards Tensorflow 2.0系列 第 26

[Day-26] 生成對抗網路 (GAN) 實作 Part I

  • 分享至 

  • xImage
  •  

今天我們來實作GAN,簡單複習一下,GAN的Component 有 Generator 以及 Discriminiator
而 Generator 任務就是產生圖片來騙過Discriminator , Discriminator的任務就是努力判斷 Generator 所產生圖片的品質。因此,"對抗" 的概念就是從此而來。在實作上,相較於之前的AutoEncoder Model,GAN大部分都是使用 Convolution layer (卷積層) ,而非像之前其他有許多的Dense layer (全連接層)。下圖就是簡易的 Generator 以及 Discriminator 的架構,GAN在training的時候非常不穩定,因此一些layer的設定或者Actvation function的選擇都要注意。

https://ithelp.ithome.com.tw/upload/images/20191011/20119971Ksl9dbMMSN.png
source

Discriminator Network:

其實 Discriminator 蠻直觀的,其實他就是一個圖片分類器,用以判斷 Generator 產生圖片的好壞。因此,我們可以來看一下如何定義 Discriminator,可以直接先透過Call function來看前向傳播的架構。

就是透過 Conv -> BN -> Con -> BN ... -> Falltern -> Dense ,注意Activation fucntion的部分都是使用Leaky Relu (Leacky Relu跟Relu最大差別就是當值小於0的時候的差別,Relu只要小於0均為0,Leaky Relu則仍會有值)。

https://ithelp.ithome.com.tw/upload/images/20191011/20119971RiqU5G1NE9.png
Relu vs Leaky Relu

class Discriminator(keras.Model):
  def __init__(self):
    super(Discriminator,self).__init__()

    self.conv_1 = layers.Conv2D(64,5,3,'valid')
    self.conv_2 = layers.Conv2D(128,5,3,'valid')
    self.bn_1 = layers.BatchNormalization()
    self.conv_3 = layers.Conv2D(256,5,3,'valid')    
    self.bn_2 = layers.BatchNormalization()
    self.flatten = layers.Flatten()
    self.fc_layer = layers.Dense(1)

  
  def call(self, inputs, training=None):
    x = tf.nn.leaky_relu(self.conv_1(inputs))    
    x = tf.nn.leaky_relu(self.bn_1(self.conv_2(x),training=training))    
    x = tf.nn.leaky_relu(self.bn_2(self.conv_3(x),training=training))  
    x = self.flatten(x)
    x = self.fc_layer(x)
    return x

Generator Network:

Generator的部分,主要為一個圖片產生器,透過一個低維度的matrix,還原成一張正常的圖片。在Generator中
,會使用 tf.layers.Conv2DTranspose (反卷積) ,簡單來說就是把特徵還原成圖片的概念 (如下圖)

https://ithelp.ithome.com.tw/upload/images/20191011/201199712GCB6ooo5d.png

接下來,可以直接先透過Call function來看前向傳播的架構。

Input -> Dense -> Conv Transpose -> BN -> .. -> Tanh

class Generator(keras.Model):
  def __init__(self):
    super(Generator,self).__init__()
    #encoder
    self.fc_layer_1 = layers.Dense(3*3*512)
    self.conv_1 = layers.Conv2DTranspose(256,3,3,'valid')

    self.bn_1 = layers.BatchNormalization()       
    self.conv_2 = layers.Conv2DTranspose(128,5,2,'valid')
    self.bn_2 = layers.BatchNormalization()     
    self.conv_3 = layers.Conv2DTranspose(3,4,3,'valid')

  def call(self, inputs, training=None):
    x = self.fc_layer_1(inputs)
    x = tf.reshape(x,[-1,3,3,512])
    x = tf.nn.leaky_relu(x)
    x = self.bn_1(self.conv_1(x),training=training)
    x = self.bn_2(self.conv_2(x),training=training)
    x = self.conv_3(x)
    x = tf.tanh(x)
    return x

我們就完成 Generator 和 Discriminator 的建置。接下來就可以做簡單的測試。

x = tf.random.normal([1,64,64,3])
z = tf.random.normal([1,100])
prob = g(x)
print(prob)
out = d(x)
print(out.shape)

https://ithelp.ithome.com.tw/upload/images/20191011/201199711RLhr5phhl.png

小結:

今天完成簡易的GAN 模型與建立,明天會跑真實的資料! 感謝大家漫長閱讀。祝大家連假愉快

一日一梗圖:

https://ithelp.ithome.com.tw/upload/images/20191011/20119971x3jdspG9pT.png
source

Reference

GAN_example


上一篇
[Day-25] 生成對抗網路 (GAN) 介紹
下一篇
[Day-27] 生成對抗網路 (GAN) 實作 Part II
系列文
Towards Tensorflow 2.030
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言